{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## K近邻算法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 引入依赖"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# 这里直接引入sklearn里的数据集， iris 鸢尾花\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split # 切分数据集为训练集和测试集\n",
    "from sklearn.metrics import accuracy_score # 用来计算分类预测的准确率"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据加载和预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>150.000000</td>\n",
       "      <td>150.000000</td>\n",
       "      <td>150.000000</td>\n",
       "      <td>150.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>5.843333</td>\n",
       "      <td>3.057333</td>\n",
       "      <td>3.758000</td>\n",
       "      <td>1.199333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.828066</td>\n",
       "      <td>0.435866</td>\n",
       "      <td>1.765298</td>\n",
       "      <td>0.762238</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>4.300000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.100000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>5.100000</td>\n",
       "      <td>2.800000</td>\n",
       "      <td>1.600000</td>\n",
       "      <td>0.300000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>5.800000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>4.350000</td>\n",
       "      <td>1.300000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>6.400000</td>\n",
       "      <td>3.300000</td>\n",
       "      <td>5.100000</td>\n",
       "      <td>1.800000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>7.900000</td>\n",
       "      <td>4.400000</td>\n",
       "      <td>6.900000</td>\n",
       "      <td>2.500000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       sepal length (cm)  sepal width (cm)  petal length (cm)  \\\n",
       "count         150.000000        150.000000         150.000000   \n",
       "mean            5.843333          3.057333           3.758000   \n",
       "std             0.828066          0.435866           1.765298   \n",
       "min             4.300000          2.000000           1.000000   \n",
       "25%             5.100000          2.800000           1.600000   \n",
       "50%             5.800000          3.000000           4.350000   \n",
       "75%             6.400000          3.300000           5.100000   \n",
       "max             7.900000          4.400000           6.900000   \n",
       "\n",
       "       petal width (cm)  \n",
       "count        150.000000  \n",
       "mean           1.199333  \n",
       "std            0.762238  \n",
       "min            0.100000  \n",
       "25%            0.300000  \n",
       "50%            1.300000  \n",
       "75%            1.800000  \n",
       "max            2.500000  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iris = load_iris()\n",
    "df = pd.DataFrame(data = iris.data,columns= iris.feature_names)\n",
    "df['class'] = iris.target\n",
    "df['class'] = df['class'].map({0:iris.target_names[0], 1:iris.target_names[1], 2:iris.target_names[2] })\n",
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(150, 4) (150, 1)\n"
     ]
    }
   ],
   "source": [
    "x = iris.data\n",
    "y = iris.target.reshape(-1, 1)\n",
    "print( x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(105, 4) (105, 1)\n",
      "(45, 4) (45, 1)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([1, 0, 4, 2, 3])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 划分训练集和测试集\n",
    "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=35, stratify=y)\n",
    "print(x_train.shape, y_train.shape)\n",
    "print(x_test.shape, y_test.shape)\n",
    "\n",
    "dists = np.array([3,2,532, 2323,45])\n",
    "\n",
    "np.argsort(dists)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 核心算法实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 距离函数的定义\n",
    "def l1_distance(a, b):\n",
    "    return np.sum(np.abs(a - b), axis=1)\n",
    "\n",
    "def l2_distance(a, b):\n",
    "    return np.sqrt( np.sum((a-b)**2, axis=1))\n",
    "\n",
    "# 分类器写法\n",
    "class KNN(object):\n",
    "    # 定义一个初始化方法, __init__ 是类的构造方法\n",
    "    def __init__(self, n_neighbors = 1, dist_func = l1_distance):\n",
    "        self.n_neighbors = n_neighbors\n",
    "        self.dist_func = dist_func\n",
    "        \n",
    "    # 训练模型的方法\n",
    "    def fit(self, x, y):\n",
    "        self.x_train = x\n",
    "        self.y_train = y\n",
    "        \n",
    "    # 模型预测方法\n",
    "    def predict(self, x):\n",
    "        # 初始化预测分类数组\n",
    "        y_pred = np.zeros((x.shape[0], 1), dtype = self.y_train.dtype)\n",
    "        \n",
    "        # 遍历输入的x 数据点, 取出每一个数据点的序号i和数据x_test\n",
    "        for i, x_test in enumerate(x):\n",
    "            # x_test 与所有训练数据计算距离\n",
    "            distances = self.dist_func(self.x_train, x_test)\n",
    "            \n",
    "            # 得到的距离按照由近到远排序,取出索引值\n",
    "            nn_index = np.argsort(distances)\n",
    "            \n",
    "            # 选取最近的k个点， 保存对应的分类类别\n",
    "            nn_y = self.y_train[ nn_index[:self.n_neighbors] ].ravel()\n",
    "            \n",
    "            # 统计类别中出现频率最高的那个，赋值给y_pred\n",
    "            y_pred[i] = np.argmax( np.bincount(nn_y) )\n",
    "            \n",
    "        return y_pred\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测准确率:  0.9333333333333333\n"
     ]
    }
   ],
   "source": [
    "knn = KNN(n_neighbors = 3)\n",
    "# 训练模型\n",
    "knn.fit(x_train, y_train)\n",
    "# 传入测试数据做预测\n",
    "y_pred = knn.predict(x_test)\n",
    "\n",
    "# 求出预测准确率\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "\n",
    "print(\"预测准确率: \", accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>K</th>\n",
       "      <th>距离函数</th>\n",
       "      <th>预测准确率</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>l1_distance</td>\n",
       "      <td>0.933333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>l1_distance</td>\n",
       "      <td>0.933333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5</td>\n",
       "      <td>l1_distance</td>\n",
       "      <td>0.977778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7</td>\n",
       "      <td>l1_distance</td>\n",
       "      <td>0.955556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>9</td>\n",
       "      <td>l1_distance</td>\n",
       "      <td>0.955556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>l2_distance</td>\n",
       "      <td>0.933333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>3</td>\n",
       "      <td>l2_distance</td>\n",
       "      <td>0.933333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>5</td>\n",
       "      <td>l2_distance</td>\n",
       "      <td>0.977778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>7</td>\n",
       "      <td>l2_distance</td>\n",
       "      <td>0.977778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>l2_distance</td>\n",
       "      <td>0.977778</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   K         距离函数     预测准确率\n",
       "0  1  l1_distance  0.933333\n",
       "1  3  l1_distance  0.933333\n",
       "2  5  l1_distance  0.977778\n",
       "3  7  l1_distance  0.955556\n",
       "4  9  l1_distance  0.955556\n",
       "5  1  l2_distance  0.933333\n",
       "6  3  l2_distance  0.933333\n",
       "7  5  l2_distance  0.977778\n",
       "8  7  l2_distance  0.977778\n",
       "9  9  l2_distance  0.977778"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn = KNN()\n",
    "# 训练模型\n",
    "knn.fit(x_train, y_train)\n",
    "# 保存结果\n",
    "result_list = []\n",
    "\n",
    "# 针对不同的参数选取做预测\n",
    "for p in [1,2]:\n",
    "    knn.dist_func = l1_distance if p == 1 else l2_distance\n",
    "    \n",
    "    #考虑不同的k 取值\n",
    "    for k in range(1, 10, 2):\n",
    "        knn.n_neighbors = k\n",
    "        \n",
    "        # 传入测试数据做预测\n",
    "        y_pred = knn.predict(x_test)\n",
    "        # 求出预测准确率\n",
    "        accuracy = accuracy_score(y_test, y_pred)\n",
    "        result_list.append([k, 'l1_distance' if p == 1 else 'l2_distance', accuracy])\n",
    "df = pd.DataFrame(result_list, columns=['K', '距离函数', '预测准确率'])\n",
    "df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
